#encoding=utf-8
import torch.nn as nn

import torch.nn.functional as F
import torch


# Pytorch 1.1后，one_hot可以直接用torch.nn.functional.one_hot


if __name__ == '__main__':
    tensor = torch.arange(0,5) % 3  # tensor([0, 1, 2, 0, 1])
    one_hot = F.one_hot(tensor)
    print(one_hot)
    # tensor([[1, 0, 0],
    #         [0, 1, 0],
    #         [0, 0, 1],
    #         [1, 0, 0],
    #         [0, 1, 0]])
    # F.one_hot会自己检测不同类别个数，生成对应独热编码。我们也可以自己指定类别数：
    # 指定num_classes的版本
    one_hot = F.one_hot(tensor,num_classes=5)
    print(one_hot)


